import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import boto3
import datetime
Why EfficientNetV2-B0?
I first heard about EfficientNet when version 2 came out. At the time, and even now, EfficientNet is one of the best, if not the best, convolutional neural network families in terms of training speed and accuracy. In other words, it is incredibly efficient.
I wanted to start with the EfficientNetV2-S model, which has about 22 million parameters, but I found that training the S model took roughly 3x as long, so I decided to go with the B0 model, which has only about 7 million.
The Code
The code can be separated into five sections:
- imports
- setting up global variables
- loading the data
- creating the model
- compiling and fitting the model
We will start with the imports: the standard TensorFlow and Keras libraries, boto3 for AWS integration (uploading the trained model to S3), and datetime for timestamping the TensorBoard log directory.
Global Variables and Mixed Precision
Next, we set the global variables and set the mixed-precision policy to mixed_float16. Modern GPUs with Tensor Cores can run 16-bit matrix multiplications much faster than 32-bit ones. Under this policy, most layers compute in 16-bit floats while keeping their weights in 32-bit, which frees up VRAM and speeds up training. The exception is the final output layer, which we will keep in 32-bit precision for numerical stability.
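To make the memory and precision trade-off concrete, here is a small NumPy sketch (NumPy is an assumption here; the original code doesn't use it). It compares how much memory a tensor the size of one input batch needs in float32 versus float16, and shows float16's narrower range and coarser precision, which is why the output layer stays in float32.

```python
import numpy as np

# Shape of one batch of images at the settings used in this post
# (BATCH_SIZE=32, IMG_SIZE=512, 3 color channels).
batch_shape = (32, 512, 512, 3)
num_values = int(np.prod(batch_shape))

# Bytes needed to hold that many values at each precision.
bytes_fp32 = num_values * np.dtype(np.float32).itemsize
bytes_fp16 = num_values * np.dtype(np.float16).itemsize
print(f"float32: {bytes_fp32 / 2**20:.0f} MiB, float16: {bytes_fp16 / 2**20:.0f} MiB")

# float16 pays for that saving with range and precision.
fp16 = np.finfo(np.float16)
print("largest float16 value:", float(fp16.max))  # 65504.0
print("float16 machine epsilon:", fp16.eps)

# A small update can vanish entirely when added in float16...
print(np.float16(1.0) + np.float16(1e-4) == np.float16(1.0))  # True
# ...but survives in float32.
print(np.float32(1.0) + np.float32(1e-4) == np.float32(1.0))  # False
```

Halving the bytes per value is where the VRAM saving comes from; the lost precision is why mixed precision keeps weights, and here the final softmax layer, in float32.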
IMG_SIZE = 512
BATCH_SIZE = 32
DATA_DIR = './dataset/'
NUM_CLASSES = 3595
PREVIOUS_EPOCHS = 0
EPOCHS_TO_TRAIN = 25
TOTAL_EPOCHS = PREVIOUS_EPOCHS + EPOCHS_TO_TRAIN
tf.keras.mixed_precision.set_global_policy('mixed_float16')
print("Mixed precision enabled.")
Loading and Preparing the Data
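Here we load the images from DATA_DIR (image_dataset_from_directory infers each image's class from the name of its subdirectory) and split them 80/20 into training and validation sets. Both calls use the same seed, so the two subsets don't overlap. We then prefetch one batch ahead, so the next batch is prepared while the current one is training. Originally I set buffer_size=AUTOTUNE, but that caused my system to run out of memory.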
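Prefetching is a producer/consumer pattern: a background worker prepares batches into a buffer while the training loop consumes them, and buffer_size caps how many finished batches can wait in that buffer. The pure-Python sketch below models the idea with a bounded queue.Queue and a hypothetical prefetch helper; it is a simplified illustration, not how tf.data is implemented. It also computes the size of one batch here: 32 float32 images at 512x512x3 take 96 MiB, so every extra buffered batch costs roughly that much host memory, which is one plausible reason a larger, automatically tuned buffer ran out of RAM.

```python
import queue
import threading

def prefetch(batches, buffer_size=1):
    """Yield items from `batches`, preparing up to `buffer_size` ahead in a background thread.

    Simplified model of dataset.prefetch(buffer_size): the bounded queue
    caps how many finished batches can sit in the buffer at once.
    """
    buf = queue.Queue(maxsize=buffer_size)
    done = object()  # sentinel marking the end of the stream

    def producer():
        for batch in batches:
            buf.put(batch)  # blocks while the buffer is full
        buf.put(done)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buf.get()
        if item is done:
            return
        yield item

# Size of one real batch: 32 images x 512 x 512 x 3 channels x 4 bytes (float32).
BATCH_BYTES = 32 * 512 * 512 * 3 * 4
print(f"one batch is {BATCH_BYTES / 2**20:.0f} MiB")  # one batch is 96 MiB

# Stand-in "batches" are just labels here.
consumed = list(prefetch((f"batch-{i}" for i in range(5)), buffer_size=1))
print(consumed)
```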
# Same seed in both calls so the 80/20 split is consistent and the subsets don't overlap.
train_ds = tf.keras.utils.image_dataset_from_directory(
    DATA_DIR,
    validation_split=0.2,
    subset="training",
    seed=1234,
    image_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
)
val_ds = tf.keras.utils.image_dataset_from_directory(
    DATA_DIR,
    validation_split=0.2,
    subset="validation",
    seed=1234,
    image_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
)
# AUTOTUNE = tf.data.AUTOTUNE
# AUTOTUNE ran out of memory on my machine, so prefetch a single batch instead.
train_ds = train_ds.prefetch(buffer_size=1)
val_ds = val_ds.prefetch(buffer_size=1)
Building the Model Architecture
This is perhaps the most important part of the code. We load the EfficientNetV2-B0 model with weights pre-trained on ImageNet, drop its classification head (include_top=False), and set base_model.trainable = False. This freezes all the layers in the base model so that the ImageNet weights don't change during training.
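Next we build our own head to replace the one we removed. Each image is preprocessed and run through the base model. (For EfficientNetV2, preprocess_input is actually a pass-through, because the model rescales its inputs internally.) Passing training=False makes the batch normalization layers run in inference mode, so they normalize with the statistics they learned on ImageNet rather than with per-batch statistics. This matters most later, if the base model is unfrozen for fine-tuning.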
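The difference between batch normalization's two modes can be sketched in NumPy. The batch_norm function below is a hypothetical, simplified stand-in (no learned scale or offset, no moving-average update), not Keras's implementation: in training mode it normalizes with the current batch's mean and variance, while in inference mode it uses stored moving statistics, so an image's output doesn't depend on what else is in the batch.

```python
import numpy as np

def batch_norm(x, moving_mean, moving_var, training, eps=1e-3):
    """Simplified batch normalization over axis 0 (no gamma/beta, no moving-average update)."""
    if training:
        mean, var = x.mean(axis=0), x.var(axis=0)  # statistics of this batch only
    else:
        mean, var = moving_mean, moving_var  # statistics stored during pre-training
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
# Hypothetical stored statistics, standing in for what ImageNet pre-training learned.
moving_mean, moving_var = np.zeros(4), np.ones(4)
# A batch whose statistics differ from the stored ones.
batch = rng.normal(loc=5.0, scale=2.0, size=(32, 4))

train_out = batch_norm(batch, moving_mean, moving_var, training=True)
infer_out = batch_norm(batch, moving_mean, moving_var, training=False)

# Training mode re-centres every batch around zero; inference mode keeps the stored statistics.
print(train_out.mean(axis=0).round(3))
print(infer_out.mean(axis=0).round(1))

# In inference mode, an image's output is the same whether it is alone or in a batch.
single = batch_norm(batch[:1], moving_mean, moving_var, training=False)
print(np.allclose(single, infer_out[:1]))  # True
```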
GlobalAveragePooling2D averages each feature map over its height and width, turning the base model's output into a 1-D feature vector per image for classification. Dropout helps reduce overfitting by randomly setting 30% of its inputs to zero during training (it does nothing at inference time). The Dense layer is our output layer: it has one unit per class, softmax turns its scores into a probability distribution over all the classes, and, as mentioned earlier, dtype='float32' sets this layer's precision back to 32-bit.
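Here is a NumPy sketch of what the head does to the base model's output. It assumes EfficientNetV2-B0's final feature map has 1280 channels and is downsampled 32x spatially, giving 16x16x1280 for a 512x512 input; the features are random stand-ins, and the hypothetical dropout function uses the "inverted" form Keras applies, which rescales the kept values by 1/(1 - rate) during training. The last lines count the output layer's parameters, which, with the base frozen, are the only weights being learned.

```python
import numpy as np

rng = np.random.default_rng(42)

# Assumed output of the frozen base for a batch of 2 images: (batch, height, width, channels).
# 512 / 32 = 16, and 1280 is assumed to be EfficientNetV2-B0's final channel count.
features = rng.random((2, 16, 16, 1280), dtype=np.float32)

# GlobalAveragePooling2D: average each channel over the spatial axes -> (batch, channels).
pooled = features.mean(axis=(1, 2))
print(pooled.shape)  # (2, 1280)

def dropout(x, rate, training, rng):
    """Inverted dropout: zero a `rate` fraction at random and rescale the rest (training only)."""
    if not training:
        return x  # dropout is a no-op at inference time
    keep = rng.random(x.shape) >= rate
    return np.where(keep, x / (1.0 - rate), 0.0)

dropped = dropout(pooled, rate=0.3, training=True, rng=rng)
print(f"fraction zeroed: {np.mean(dropped == 0):.2f}")  # close to 0.30

# Dense(NUM_CLASSES) on 1280 features: one weight per (feature, class) pair plus one bias per class.
NUM_CLASSES = 3595
head_params = 1280 * NUM_CLASSES + NUM_CLASSES
print(f"{head_params:,} trainable parameters in the output layer")  # 4,605,195
```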
# Data augmentation (commented out, not used in this run):
# data_augmentation = keras.Sequential([
#     layers.RandomRotation(0.1),
#     layers.RandomZoom(0.15),
#     layers.RandomContrast(0.2),
# ], name="augmentation")
# def augment_data(image, label):
#     return data_augmentation(image, training=True), label
# Pre-trained EfficientNetV2-B0 without its ImageNet classification head.
base_model = keras.applications.EfficientNetV2B0(
    weights='imagenet',
    include_top=False,
    input_shape=(IMG_SIZE, IMG_SIZE, 3),
)
base_model.trainable = False  # freeze all base-model weights
inputs = keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
x = keras.applications.efficientnet_v2.preprocess_input(inputs)
x = base_model(x, training=False)  # keep batch norm in inference mode
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.3)(x)
# Keep the output layer in float32 for numerical stability under mixed precision.
outputs = layers.Dense(NUM_CLASSES, activation="softmax", dtype='float32')(x)
model = keras.Model(inputs, outputs)
Compiling and Training the Model
keras.optimizers.schedules.ExponentialDecay creates a learning rate schedule: instead of using a fixed learning rate, the rate starts at 1e-3 and is multiplied by 0.9 for every 10,000 training steps (batches), decaying smoothly in between. As training goes on, the model takes smaller and smaller steps to fine-tune its weights.
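According to the Keras documentation, with the default staircase=False this schedule computes initial_learning_rate * decay_rate ** (step / decay_steps), where step counts optimizer updates (batches), not epochs. The plain-Python sketch below reproduces that formula with the values used here and evaluates it at a few epoch boundaries, assuming the 17,975 steps per epoch shown in the training log.

```python
def exponential_decay(step, initial_lr=1e-3, decay_steps=10_000, decay_rate=0.9):
    """Learning rate after `step` optimizer updates (continuous decay, i.e. staircase=False)."""
    return initial_lr * decay_rate ** (step / decay_steps)

STEPS_PER_EPOCH = 17_975  # batches per epoch, taken from the training log

for epoch in (1, 5, 13, 18):
    lr = exponential_decay(epoch * STEPS_PER_EPOCH)
    print(f"after epoch {epoch:2d}: learning rate = {lr:.2e}")
```

By epoch 13, where validation loss stopped improving, the rate has already dropped below a tenth of its starting value, and by epoch 18 it is roughly 3% of it.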
We compile the model with Adam as the optimizer, driven by our learning rate schedule. For the loss we use sparse_categorical_crossentropy: this is a multi-class problem, and image_dataset_from_directory returns labels as integer class indices by default, which is exactly the format the sparse variant expects (no one-hot encoding needed).
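The sparse and regular versions of categorical cross-entropy compute the same number; they differ only in label format. The plain-Python sketch below (the helper names are hypothetical, not Keras functions) shows that an integer label gives the same loss as its one-hot encoding. It also computes the loss of a model that guesses uniformly over all 3,595 classes, ln(3595) ≈ 8.19, which is a useful reference point for the epoch-1 training loss of 2.80 in the log.

```python
import math

def sparse_cce(probs, label):
    """Cross-entropy with an integer class index as the label."""
    return -math.log(probs[label])

def categorical_cce(probs, one_hot):
    """Cross-entropy with a one-hot vector as the label."""
    return -sum(t * math.log(p) for t, p in zip(one_hot, probs))

probs = [0.1, 0.7, 0.2]  # softmax output for a toy 3-class problem
label = 1                # integer label, as image_dataset_from_directory yields by default
one_hot = [0, 1, 0]      # the same label, one-hot encoded

print(sparse_cce(probs, label), categorical_cce(probs, one_hot))  # both equal -ln(0.7)

# Loss and accuracy of a model that assigns equal probability to all 3,595 classes.
NUM_CLASSES = 3595
uniform_loss = sparse_cce([1 / NUM_CLASSES] * NUM_CLASSES, label=0)
print(f"uniform-guess loss: {uniform_loss:.2f}, chance accuracy: {1 / NUM_CLASSES:.4%}")
```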
Next come the callbacks: hooks that Keras calls at set points during training, such as the end of every epoch. ModelCheckpoint runs after every epoch, and with save_best_only=True it only writes the model to file when the validation loss improves on the best value seen so far. EarlyStopping stops training automatically once the validation loss has gone patience=5 epochs without improving, and restore_best_weights=True then restores the weights from the best epoch. I will go into this in more detail in the results below.
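The EarlyStopping logic can be replayed on the val_loss values from the training log. The plain-Python early_stopping function below is a simplified re-implementation, assuming Keras's defaults of monitor='val_loss' and min_delta=0 (an epoch only counts as an improvement if val_loss is strictly lower than the best so far); it is a sketch of the behavior, not the callback's source.

```python
def early_stopping(val_losses, patience=5, min_delta=0.0):
    """Return (epoch training stops after, best epoch), with epochs numbered from 1."""
    best, best_epoch, wait = float("inf"), 0, 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:  # improvement: remember it and reset the counter
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return len(val_losses), best_epoch  # ran out of epochs without triggering

# val_loss per epoch, copied from the training log (epochs 1-18).
val_losses = [1.1696, 0.8373, 0.7192, 0.6566, 0.6203, 0.5927, 0.5743, 0.5596, 0.5492,
              0.5404, 0.5335, 0.5279, 0.5264, 0.5264, 0.5264, 0.5264, 0.5264, 0.5264]

stopped, best_epoch = early_stopping(val_losses, patience=5)
print(f"stopped after epoch {stopped}; restore_best_weights rolls back to epoch {best_epoch}")
```

With these (rounded) values, the five non-improving epochs are 14 through 18, which matches the point where the real run stopped.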
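TensorBoard is TensorFlow's tool for visualizing the training process: loss and metric curves, weight histograms (histogram_freq=1), and, through the profiler (profile_batch='5,15' profiles batches 5 through 15), details such as CPU and GPU usage. The code below passes this callback to model.fit, but in my actual training run I forgot to include it, so I wasn't able to use this tool.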
Finally, we start training by calling model.fit with the training and validation datasets and our callbacks.
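The commented-out initial_epoch=PREVIOUS_EPOCHS line, together with the PREVIOUS_EPOCHS and TOTAL_EPOCHS globals, appears to be set up for resuming an earlier run. In Keras, when initial_epoch is given, epochs is treated as the index of the final epoch, so fit trains epochs initial_epoch through epochs - 1. The sketch below uses a hypothetical epochs_fit_would_run helper to show which epochs would run, including a hypothetical resume after 18 completed epochs.

```python
def epochs_fit_would_run(initial_epoch, epochs):
    """0-based epoch indices that model.fit(epochs=epochs, initial_epoch=initial_epoch) trains."""
    return list(range(initial_epoch, epochs))

EPOCHS_TO_TRAIN = 25

# Fresh run, as in this post: PREVIOUS_EPOCHS = 0.
print(len(epochs_fit_would_run(0, 0 + EPOCHS_TO_TRAIN)))  # 25

# Hypothetical resume after 18 completed epochs: PREVIOUS_EPOCHS = 18, TOTAL_EPOCHS = 43.
resumed = epochs_fit_would_run(18, 18 + EPOCHS_TO_TRAIN)
print(resumed[0] + 1, resumed[-1] + 1, len(resumed))  # 19 43 25
```

Without initial_epoch, a resumed run with epochs=TOTAL_EPOCHS would train for all 43 epochs instead of the intended 25.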
# Decay the learning rate from 1e-3 by a factor of 0.9 per 10,000 steps (smoothly, not stepwise).
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-3,
    decay_steps=10000,
    decay_rate=0.9)
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
    loss="sparse_categorical_crossentropy",  # labels are integer class indices
    metrics=["accuracy"],
)
# Save the model only when val_loss improves on the best value so far.
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    "scarlet_violet_model_1.keras",
    save_best_only=True
)
# Stop after 5 epochs without val_loss improvement and keep the best weights.
early_stopping_cb = keras.callbacks.EarlyStopping(
    patience=5, restore_best_weights=True
)
log_dir = "./logs/profile/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# Log metrics, weight histograms, and a profile of batches 5-15 for TensorBoard.
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
    profile_batch='5,15'
)
model.fit(
    train_ds,
    epochs=TOTAL_EPOCHS,
    # initial_epoch=PREVIOUS_EPOCHS,
    validation_data=val_ds,
    callbacks=[checkpoint_cb, early_stopping_cb, tensorboard_callback],
)
Training Results
Here is my training output:
Epoch 1/25
17975/17975 [==============================] - 2821s 156ms/step - loss: 2.8047 - accuracy: 0.5668 - val_loss: 1.1696 - val_accuracy: 0.8127
Epoch 2/25
17975/17975 [==============================] - 2806s 156ms/step - loss: 1.0823 - accuracy: 0.8073 - val_loss: 0.8373 - val_accuracy: 0.8571
Epoch 3/25
17975/17975 [==============================] - 2807s 156ms/step - loss: 0.7964 - accuracy: 0.8529 - val_loss: 0.7192 - val_accuracy: 0.8733
Epoch 4/25
17975/17975 [==============================] - 2789s 155ms/step - loss: 0.6654 - accuracy: 0.8746 - val_loss: 0.6566 - val_accuracy: 0.8827
Epoch 5/25
17975/17975 [==============================] - 2728s 152ms/step - loss: 0.5847 - accuracy: 0.8882 - val_loss: 0.6203 - val_accuracy: 0.8880
Epoch 6/25
17975/17975 [==============================] - 2808s 156ms/step - loss: 0.5315 - accuracy: 0.8975 - val_loss: 0.5927 - val_accuracy: 0.8915
Epoch 7/25
17975/17975 [==============================] - 2805s 156ms/step - loss: 0.4903 - accuracy: 0.9051 - val_loss: 0.5743 - val_accuracy: 0.8944
Epoch 8/25
17975/17975 [==============================] - 2781s 155ms/step - loss: 0.4628 - accuracy: 0.9102 - val_loss: 0.5596 - val_accuracy: 0.8967
Epoch 9/25
17975/17975 [==============================] - 2807s 156ms/step - loss: 0.4381 - accuracy: 0.9145 - val_loss: 0.5492 - val_accuracy: 0.8981
Epoch 10/25
17975/17975 [==============================] - 2783s 155ms/step - loss: 0.4205 - accuracy: 0.9180 - val_loss: 0.5404 - val_accuracy: 0.8998
Epoch 11/25
17975/17975 [==============================] - 2809s 156ms/step - loss: 0.4072 - accuracy: 0.9202 - val_loss: 0.5335 - val_accuracy: 0.9011
Epoch 12/25
17975/17975 [==============================] - 2805s 156ms/step - loss: 0.3955 - accuracy: 0.9228 - val_loss: 0.5279 - val_accuracy: 0.9015
Epoch 13/25
17975/17975 [==============================] - 2807s 156ms/step - loss: 0.3871 - accuracy: 0.9240 - val_loss: 0.5264 - val_accuracy: 0.9027
Epoch 14/25
17975/17975 [==============================] - 2804s 156ms/step - loss: 0.3744 - accuracy: 0.9263 - val_loss: 0.5264 - val_accuracy: 0.9027
Epoch 15/25
17975/17975 [==============================] - 2805s 156ms/step - loss: 0.3739 - accuracy: 0.9262 - val_loss: 0.5264 - val_accuracy: 0.9027
Epoch 16/25
17975/17975 [==============================] - 2808s 156ms/step - loss: 0.3731 - accuracy: 0.9268 - val_loss: 0.5264 - val_accuracy: 0.9027
Epoch 17/25
17975/17975 [==============================] - 2807s 156ms/step - loss: 0.3742 - accuracy: 0.9264 - val_loss: 0.5264 - val_accuracy: 0.9027
Epoch 18/25
17975/17975 [==============================] - 2810s 156ms/step - loss: 0.3744 - accuracy: 0.9267 - val_loss: 0.5264 - val_accuracy: 0.9027
Uploading model to S3...
Upload complete.
As you can see, from epoch 13 onward val_loss and val_accuracy stayed at 0.5264 and 0.9027. Even though I set training to run for 25 epochs, it stopped after epoch 18 because the EarlyStopping callback ends training after 5 epochs without improvement.
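The per-epoch times in the log also put numbers on the cost of this run. The short calculation below simply adds up the logged epoch durations and multiplies steps by batch size; treating 17,975 x 32 as the number of training images is an approximation, since the last batch of an epoch can be partial.

```python
# Seconds per epoch, copied from the training log (epochs 1-18).
epoch_seconds = [2821, 2806, 2807, 2789, 2728, 2808, 2805, 2781, 2807,
                 2783, 2809, 2805, 2807, 2804, 2805, 2808, 2807, 2810]

total_hours = sum(epoch_seconds) / 3600
after_best_hours = sum(epoch_seconds[13:]) / 3600  # epochs 14-18, after the best epoch (13)
print(f"total: {total_hours:.1f} h, after the best epoch: {after_best_hours:.1f} h")
# total: 14.0 h, after the best epoch: 3.9 h

# Approximate number of training images seen per epoch.
STEPS_PER_EPOCH, BATCH_SIZE = 17_975, 32
print(f"about {STEPS_PER_EPOCH * BATCH_SIZE:,} training images per epoch")
```

So patience=5 spent roughly four extra hours confirming the plateau; a smaller patience would save time at the risk of stopping too early.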
Analyzing the Results
Looking at the training logs, we can observe several interesting patterns:
Training Progress
- Epoch 1: The model finishes its first epoch at 56.68% training accuracy, having quickly learned basic patterns
- Epochs 1-5: Rapid improvement phase - accuracy jumps from 56.68% to 88.82%
- Epochs 6-12: Steady improvement - accuracy climbs from 89.75% to 92.28%
- Epochs 13-18: Plateau phase - accuracy hovers around 92.4-92.7%
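To go beyond eyeballing the log, the Keras progress lines can be parsed into numbers. The regular expression below is written against the exact line format shown above (an assumption: other Keras versions print slightly different progress bars), and the sketch runs on two lines copied from the log.

```python
import re

# Matches the "name: value" pairs Keras prints at the end of each epoch line.
METRIC_RE = re.compile(r"(\w+): (\d+\.\d+)")

def parse_epoch_line(line):
    """Return a dict such as {'loss': 2.8047, 'accuracy': 0.5668, ...} for one log line."""
    return {name: float(value) for name, value in METRIC_RE.findall(line)}

log_lines = [
    "17975/17975 [==============================] - 2821s 156ms/step - loss: 2.8047 - accuracy: 0.5668 - val_loss: 1.1696 - val_accuracy: 0.8127",
    "17975/17975 [==============================] - 2810s 156ms/step - loss: 0.3744 - accuracy: 0.9267 - val_loss: 0.5264 - val_accuracy: 0.9027",
]

history = [parse_epoch_line(line) for line in log_lines]
last = history[-1]
print(last)
print(f"train/val accuracy gap at the end: {last['accuracy'] - last['val_accuracy']:.2%}")
```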
Validation Performance
The validation accuracy reached 90.27%, which is quite impressive considering:
- We’re classifying among 3,595 different classes
- Each class represents a specific Pokemon card
- The model has never seen these validation images during training
Early Stopping Behavior
The early stopping callback kicked in after epoch 18 because:
- The validation loss stopped improving after epoch 13
- It then stayed at 0.5264 for 5 consecutive epochs (patience=5)
This suggests the model had plateaued with the current architecture and frozen base.
What the Numbers Tell Us
The final metrics show:
- Training accuracy: 92.67% - the model learned the training data well
- Validation accuracy: 90.27% - good generalization, with only a ~2.4 percentage-point gap
The small gap between training and validation accuracy suggests the model is not overfitting much. Note that the data_augmentation layers in the code above are commented out, so they were not applied in this run.
Future Considerations
- I want to retrain the model using a larger model like EfficientNetV2-S, or even V2-M, on a more powerful machine.
- I want to unfreeze the base model after this initial training and fine-tune some of its layers.
- I want to validate my model with real-world photos.
- The plan is to have photos of every Pokemon card in this series to validate the model’s accuracy.
- I want to analyze and evaluate the model’s performance in more depth, which I unfortunately did not have time for in this project.
Potential Improvements
Model Architecture
- Unfreeze base layers: After initial training, gradually unfreeze the EfficientNet layers for fine-tuning
- Try larger models: EfficientNetV2-S or V2-M might capture more subtle differences between cards
- Add attention mechanisms: Could help the model focus on distinguishing features like card numbers or specific artwork details
Training Strategy
- Progressive unfreezing: Start with frozen base, then gradually unfreeze layers
- Learning rate scheduling: More aggressive decay once validation plateaus
- Longer training: The model might benefit from more epochs with a lower learning rate
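"More aggressive decay once validation plateaus" is what Keras's ReduceLROnPlateau callback does: it multiplies the learning rate by a factor after a set number of epochs without improvement. The plain-Python reduce_on_plateau function below is a simplified, hypothetical version of that logic (no min_delta, cooldown, or minimum rate), replayed on the logged val_loss values with made-up settings (initial_lr=1e-4, factor=0.5, patience=2) to show when it would have fired. Since that callback adjusts a plain learning rate, using it here would likely mean replacing the ExponentialDecay schedule with a fixed starting rate.

```python
def reduce_on_plateau(val_losses, initial_lr, factor=0.5, patience=2):
    """Learning rate in effect after each epoch under a simplified reduce-on-plateau rule."""
    lr, best, wait, schedule = initial_lr, float("inf"), 0, []
    for loss in val_losses:
        if loss < best:
            best, wait = loss, 0
        else:
            wait += 1
            if wait >= patience:
                lr *= factor  # plateau detected: shrink the step size
                wait = 0
        schedule.append(lr)
    return schedule

# val_loss for epochs 11-18, copied from the training log.
val_losses = [0.5335, 0.5279, 0.5264, 0.5264, 0.5264, 0.5264, 0.5264, 0.5264]
for epoch, lr in enumerate(reduce_on_plateau(val_losses, initial_lr=1e-4), start=11):
    print(f"after epoch {epoch}: lr = {lr:.2e}")
```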
Data Considerations
- Augmentation balance: Current augmentations are quite aggressive - might try more conservative settings
- Class imbalance: Some cards may have more images than others - could use weighted loss
- Validation strategy: K-fold cross-validation would give better performance estimates
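For the weighted-loss idea, Keras's model.fit accepts a class_weight dictionary mapping each class index to a weight. One common choice, sketched below in plain Python with made-up label counts, is the "balanced" heuristic n_samples / (n_classes * count), the same formula scikit-learn uses for class_weight='balanced', which gives rarer classes proportionally more weight. Whether it helps here would depend on how imbalanced the card images actually are.

```python
from collections import Counter

def balanced_class_weights(labels):
    """Map each class index to n_samples / (n_classes * count), as in the 'balanced' heuristic."""
    counts = Counter(labels)
    n_samples, n_classes = len(labels), len(counts)
    return {cls: n_samples / (n_classes * count) for cls, count in sorted(counts.items())}

# Made-up integer labels for a toy 4-class dataset: classes 2 and 3 are four times rarer than class 0.
labels = [0] * 40 + [1] * 20 + [2] * 10 + [3] * 10
weights = balanced_class_weights(labels)
print(weights)  # {0: 0.5, 1: 1.0, 2: 2.0, 3: 2.0}
```

Each class then contributes the same total weight to the loss (here 20 per class); the dictionary would be passed as model.fit(..., class_weight=weights).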
Real-World Testing
The ultimate test will be how well this model performs on actual photos of Pokemon cards taken with a phone camera under various lighting conditions and angles. This is something I plan to explore in future work.